package certmonitor

import (
	"context"
	"crypto/x509"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
	"github.com/k3s-io/k3s/pkg/daemons/control/deps"
	"github.com/k3s-io/k3s/pkg/metrics"
	"github.com/k3s-io/k3s/pkg/util"
	"github.com/k3s-io/k3s/pkg/util/services"
	"github.com/k3s-io/k3s/pkg/version"
	"github.com/prometheus/client_golang/prometheus"
	certutil "github.com/rancher/dynamiclistener/cert"
	"github.com/rancher/wrangler/v3/pkg/merr"
	"github.com/sirupsen/logrus"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/util/wait"
)

var (
	// certCheckInterval is how often the monitor re-examines certificates.
	// Kubernetes events have a one-hour TTL by default; checking twice an hour
	// keeps repeated warnings inside that window, so the event recorder can
	// aggregate and refresh similar events instead of letting them lapse.
	certCheckInterval = 30 * time.Minute

	// controllerName is the component name used for the event recorder and in
	// log messages emitted by this monitor.
	controllerName = version.Program + "-cert-monitor"

	// certificateExpirationSeconds holds, per certificate subject and
	// comma-separated usage list, the number of seconds remaining until the
	// certificate's NotAfter time.
	certificateExpirationSeconds = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: version.Program + "_certificate_expiration_seconds",
			Help: "Remaining lifetime on the certificate.",
		},
		[]string{"subject", "usages"},
	)
)

// Setup starts the certificate expiration monitor.
//
// It registers the certificate expiration gauge and builds an event recorder
// from the kubelet kubeconfig. It then resolves the certificate files to watch:
//   - the agent service certificates, always;
//   - every service's certificates plus the CA certificates, when
//     <dataDir>/server exists (presumably only on server nodes).
//
// Finally it starts a goroutine that re-checks those files every
// certCheckInterval until ctx is done. Whenever a check finds problems, the
// goroutine records a Warning event against this node's Node object.
//
// Setup returns an error if metrics registration, client creation, or
// certificate file resolution fails. It does not wait for any check to run.
func Setup(ctx context.Context, nodeConfig *daemonconfig.Node, dataDir string) error {
	logrus.Debugf("Starting %s with monitoring period %s", controllerName, certCheckInterval)

	if err := metrics.DefaultRegisterer.Register(certificateExpirationSeconds); err != nil {
		// The gauge is a package-level singleton, so finding it already
		// registered (e.g. on a repeated Setup call) is harmless. Any other
		// registration failure is returned instead of panicking, which is what
		// MustRegister would do.
		if _, ok := err.(prometheus.AlreadyRegisteredError); !ok {
			return fmt.Errorf("registering %s metrics: %w", controllerName, err)
		}
	}

	client, err := util.GetClientSet(nodeConfig.AgentConfig.KubeConfigKubelet)
	if err != nil {
		return err
	}

	recorder := util.BuildControllerEventRecorder(client, controllerName, metav1.NamespaceDefault)

	// This is consistent with events attached to the node generated by the kubelet
	// https://github.com/kubernetes/kubernetes/blob/612130dd2f4188db839ea5c2dea07a96b0ad8d1c/pkg/kubelet/kubelet.go#L479-L485
	nodeRef := &corev1.ObjectReference{
		Kind:      "Node",
		Name:      nodeConfig.AgentConfig.NodeName,
		UID:       types.UID(nodeConfig.AgentConfig.NodeName),
		Namespace: "",
	}

	// Create a dummy controlConfig just to hold the paths for the server certs
	controlConfig := daemonconfig.Control{
		DataDir: filepath.Join(dataDir, "server"),
		Runtime: &daemonconfig.ControlRuntime{},
	}
	deps.CreateRuntimeCertFiles(&controlConfig)

	// Only the agent's certificates are watched unless the server data
	// directory is present; then all services and the CAs are watched too.
	caMap := map[string][]string{}
	nodeList := services.Agent
	if _, statErr := os.Stat(controlConfig.DataDir); statErr == nil {
		nodeList = services.All
		if caMap, err = services.FilesForServices(controlConfig, services.CA); err != nil {
			return err
		}
	}

	nodeMap, err := services.FilesForServices(controlConfig, nodeList)
	if err != nil {
		return err
	}

	// Leaf certificates are flagged once they enter the automatic renewal
	// window. CA certificates are flagged a full year ahead, because that
	// warning asks the operator to plan a rotation.
	nodeWarningPeriod := time.Hour * 24 * daemonconfig.CertificateRenewDays
	caWarningPeriod := time.Hour * 24 * 365

	go wait.Until(func() {
		logrus.Debugf("Running %s certificate expiration check", controllerName)
		if err := checkCerts(nodeMap, nodeWarningPeriod); err != nil {
			message := fmt.Sprintf("Node certificates require attention - restart %s on this node to trigger automatic rotation: %v", version.Program, err)
			recorder.Event(nodeRef, corev1.EventTypeWarning, "CertificateExpirationWarning", message)
		}
		if err := checkCerts(caMap, caWarningPeriod); err != nil {
			message := fmt.Sprintf("Certificate authority certificates require attention - check %s documentation and begin planning rotation: %v", version.Program, err)
			recorder.Event(nodeRef, corev1.EventTypeWarning, "CACertificateExpirationWarning", message)
		}
	}, certCheckInterval, ctx.Done())

	return nil
}

// checkCerts inspects every PEM certificate listed in fileMap and returns an
// aggregated error for each certificate that is:
//   - not yet valid,
//   - already expired, or
//   - due to expire within warningPeriod.
//
// fileMap maps a service name to the certificate file paths for that service.
// It returns nil when no certificate needs attention.
//
// As a side effect, each certificate's remaining lifetime in seconds
// (NotAfter minus now) is written to the certificateExpirationSeconds gauge.
// The gauge is labelled with the certificate subject and its key usages.
//
// Services are visited in sorted order, so repeated checks over the same
// certificates produce an identical error message. Go randomizes map
// iteration order, and a message that reorders on every run cannot be
// recognized by the event recorder as a repeat of an earlier event.
//
// Files that cannot be read or parsed are skipped rather than reported. The
// failure is logged at debug level.
func checkCerts(fileMap map[string][]string, warningPeriod time.Duration) error {
	errs := merr.Errors{}
	now := time.Now()
	warn := now.Add(warningPeriod)

	serviceNames := make([]string, 0, len(fileMap))
	for service := range fileMap {
		serviceNames = append(serviceNames, service)
	}
	sort.Strings(serviceNames)

	for _, service := range serviceNames {
		for _, file := range fileMap[service] {
			basename := filepath.Base(file)
			certs, err := certutil.CertsFromFile(file)
			if err != nil {
				// Best effort: an unreadable file is not itself a certificate
				// problem, but record it for troubleshooting.
				logrus.Debugf("%s: skipping %s/%s: %v", controllerName, service, basename, err)
				continue
			}
			for _, cert := range certs {
				usages := []string{}
				if cert.KeyUsage&x509.KeyUsageCertSign != 0 {
					usages = append(usages, "CertSign")
				}
				for _, eku := range cert.ExtKeyUsage {
					switch eku {
					case x509.ExtKeyUsageServerAuth:
						usages = append(usages, "ServerAuth")
					case x509.ExtKeyUsageClientAuth:
						usages = append(usages, "ClientAuth")
					}
				}
				certificateExpirationSeconds.WithLabelValues(cert.Subject.String(), strings.Join(usages, ",")).Set(cert.NotAfter.Sub(now).Seconds())

				switch {
				case now.Before(cert.NotBefore):
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s is not valid before %s", service, basename, cert.Subject, cert.NotBefore.Format(time.RFC3339)))
				case now.After(cert.NotAfter):
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s expired at %s", service, basename, cert.Subject, cert.NotAfter.Format(time.RFC3339)))
				case warn.After(cert.NotAfter):
					errs = append(errs, fmt.Errorf("%s/%s: certificate %s will expire within %d days at %s", service, basename, cert.Subject, int(warningPeriod.Hours()/24), cert.NotAfter.Format(time.RFC3339)))
				}
			}
		}
	}

	return merr.NewErrors(errs...)
}
